{
  "cells": [
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "LqNpENf-ec0X",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "!pip install -U tf-nightly"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "Pa2qpEmoVOGe",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "import os\n",
        "import time\n",
        "\n",
        "import tensorflow as tf\n",
        "from tensorflow.contrib import autograph\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "import numpy as np\n",
        "import six\n",
        "\n",
        "from google.colab import widgets"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HNqUFL4deCsL",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "# Case study: training a custom RNN, using Keras and Estimators\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "YkC1k4HEQ7rw",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "In this section, we show how you can use AutoGraph to build RNNColorbot, an RNN that takes as input names of colors and predicts their corresponding RGB tuples. The model will be trained by a [custom Estimator](https://www.tensorflow.org/get_started/custom_estimators)."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "7nkPDl5CTCNb",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "To get started, set up the dataset. The following cells defines methods that download and format the data needed for RNNColorbot; the details aren't important (read them in the privacy of your own home if you so wish), but make sure to run the cells before proceeding."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "A0uREmVXCQEw",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "def parse(line):\n",
        "  \"\"\"Parses a line from the colors dataset.\"\"\"\n",
        "  items = tf.string_split([line], \",\").values\n",
        "  rgb = tf.string_to_number(items[1:], out_type=tf.float32) / 255.0\n",
        "  color_name = items[0]\n",
        "  chars = tf.one_hot(tf.decode_raw(color_name, tf.uint8), depth=256)\n",
        "  length = tf.cast(tf.shape(chars)[0], dtype=tf.int64)\n",
        "  return rgb, chars, length\n",
        "\n",
        "\n",
        "def set_static_batch_shape(batch_size):\n",
        "  def apply(rgb, chars, length):\n",
        "    rgb.set_shape((batch_size, None))\n",
        "    chars.set_shape((batch_size, None, 256))\n",
        "    length.set_shape((batch_size,))\n",
        "    return rgb, chars, length\n",
        "  return apply\n",
        "\n",
        "\n",
        "def load_dataset(data_dir, url, batch_size, training=True):\n",
        "  \"\"\"Loads the colors data at path into a tf.PaddedDataset.\"\"\"\n",
        "  path = tf.keras.utils.get_file(os.path.basename(url), url, cache_dir=data_dir)\n",
        "  dataset = tf.data.TextLineDataset(path)\n",
        "  dataset = dataset.skip(1)\n",
        "  dataset = dataset.map(parse)\n",
        "  dataset = dataset.cache()\n",
        "  dataset = dataset.repeat()\n",
        "  if training:\n",
        "    dataset = dataset.shuffle(buffer_size=3000)\n",
        "  dataset = dataset.padded_batch(\n",
        "      batch_size, padded_shapes=((None,), (None, 256), ()))\n",
        "  # To simplify the model code, we statically set as many of the shapes that we\n",
        "  # know.\n",
        "  dataset = dataset.map(set_static_batch_shape(batch_size))\n",
        "  return dataset"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "waZ89t3DTUla",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "To show the use of control flow, we write the RNN loop by hand, rather than using a pre-built RNN model.\n",
        "\n",
        "Note how we write the model code in Eager style, with regular `if` and `while` statements. Then, we annotate the functions with `@autograph.convert` to have them automatically compiled to run in graph mode.\n",
        "We use Keras to define the model, and we will train it using Estimators."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "9v8AJouiC44V",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "@autograph.convert()\n",
        "class RnnColorbot(tf.keras.Model):\n",
        "  \"\"\"RNN Colorbot model.\"\"\"\n",
        "\n",
        "  def __init__(self):\n",
        "    super(RnnColorbot, self).__init__()\n",
        "    self.lower_cell = tf.contrib.rnn.LSTMBlockCell(256, dtype=tf.float32)\n",
        "    self.upper_cell = tf.contrib.rnn.LSTMBlockCell(128, dtype=tf.float32)\n",
        "    self.relu_layer = tf.layers.Dense(3, activation=tf.nn.relu)\n",
        "\n",
        "  def _rnn_layer(self, chars, cell, batch_size, training):\n",
        "    \"\"\"A single RNN layer.\n",
        "\n",
        "    Args:\n",
        "      chars: A Tensor of shape (max_sequence_length, batch_size, input_size)\n",
        "      cell: An object of type tf.contrib.rnn.LSTMBlockCell\n",
        "      batch_size: Int, the batch size to use\n",
        "      training: Boolean, whether the layer is used for training\n",
        "\n",
        "    Returns:\n",
        "      A Tensor of shape (max_sequence_length, batch_size, output_size).\n",
        "    \"\"\"\n",
        "    hidden_outputs = tf.TensorArray(tf.float32, 0, True)\n",
        "    state, output = cell.zero_state(batch_size, tf.float32)\n",
        "    for ch in chars:\n",
        "      cell_output, (state, output) = cell.call(ch, (state, output))\n",
        "      hidden_outputs.append(cell_output)\n",
        "    hidden_outputs = autograph.stack(hidden_outputs)\n",
        "    if training:\n",
        "      hidden_outputs = tf.nn.dropout(hidden_outputs, 0.5)\n",
        "    return hidden_outputs\n",
        "\n",
        "  def build(self, _):\n",
        "    \"\"\"Creates the model variables. See keras.Model.build().\"\"\"\n",
        "    self.lower_cell.build(tf.TensorShape((None, 256)))\n",
        "    self.upper_cell.build(tf.TensorShape((None, 256)))\n",
        "    self.relu_layer.build(tf.TensorShape((None, 128)))    \n",
        "    self.built = True\n",
        "\n",
        "\n",
        "  def call(self, inputs, training=False):\n",
        "    \"\"\"The RNN model code. Uses Eager.\n",
        "\n",
        "    The model consists of two RNN layers (made by lower_cell and upper_cell),\n",
        "    followed by a fully connected layer with ReLU activation.\n",
        "\n",
        "    Args:\n",
        "      inputs: A tuple (chars, length)\n",
        "      training: Boolean, whether the layer is used for training\n",
        "\n",
        "    Returns:\n",
        "      A Tensor of shape (batch_size, 3) - the model predictions.\n",
        "    \"\"\"\n",
        "    chars, length = inputs\n",
        "    batch_size = chars.shape[0]\n",
        "    seq = tf.transpose(chars, (1, 0, 2))\n",
        "\n",
        "    seq = self._rnn_layer(seq, self.lower_cell, batch_size, training)\n",
        "    seq = self._rnn_layer(seq, self.upper_cell, batch_size, training)\n",
        "\n",
        "    # Grab just the end-of-sequence from each output.\n",
        "    indices = (length - 1, list(range(batch_size)))\n",
        "    indices = tf.stack(indices, 1)\n",
        "    sequence_ends = tf.gather_nd(seq, indices)\n",
        "    return self.relu_layer(sequence_ends)\n",
        "\n",
        "@autograph.convert()\n",
        "def loss_fn(labels, predictions):\n",
        "  return tf.reduce_mean((predictions - labels) ** 2)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "JjK4gXFvFsf4",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "We will now create the model function for the custom Estimator.\n",
        "\n",
        "In the model function, we simply use the model class we defined above - that's it!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "-yso_Nx23Gy1",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [],
      "source": [
        "def model_fn(features, labels, mode, params):\n",
        "  \"\"\"Estimator model function.\"\"\"\n",
        "  chars = features['chars']\n",
        "  sequence_length = features['sequence_length']\n",
        "  inputs = (chars, sequence_length)\n",
        "\n",
        "  # Create the model. Simply using the AutoGraph-ed class just works!\n",
        "  colorbot = RnnColorbot()\n",
        "  colorbot.build(None)\n",
        "\n",
        "  if mode == tf.estimator.ModeKeys.TRAIN:\n",
        "    predictions = colorbot(inputs, training=True)\n",
        "    loss = loss_fn(labels, predictions)\n",
        "\n",
        "    learning_rate = params['learning_rate']\n",
        "    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)\n",
        "    global_step = tf.train.get_global_step()\n",
        "    train_op = optimizer.minimize(loss, global_step=global_step)\n",
        "    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)\n",
        "\n",
        "  elif mode == tf.estimator.ModeKeys.EVAL:\n",
        "    predictions = colorbot(inputs)\n",
        "    loss = loss_fn(labels, predictions)\n",
        "\n",
        "    return tf.estimator.EstimatorSpec(mode, loss=loss)\n",
        "\n",
        "  elif mode == tf.estimator.ModeKeys.PREDICT:\n",
        "    predictions = colorbot(inputs)\n",
        "\n",
        "    predictions = tf.minimum(predictions, 1.0)\n",
        "    return tf.estimator.EstimatorSpec(mode, predictions=predictions)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "HOQfoBnHC9CP",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "We'll create an input function that will feed our training and eval data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 0,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          }
        },
        "colab_type": "code",
        "id": "FJZlx7yG2MP0",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [],
      "source": [
        "def input_fn(data_dir, data_url, params, training=True):\n",
        "  \"\"\"An input function for training\"\"\"\n",
        "  batch_size = params['batch_size']\n",
        "  \n",
        "  # load_dataset defined above\n",
        "  dataset = load_dataset(data_dir, data_url, batch_size, training=training)\n",
        "\n",
        "  # Package the pipeline end in a format suitable for the estimator.\n",
        "  labels, chars, sequence_length = dataset.make_one_shot_iterator().get_next()\n",
        "  features = {\n",
        "      'chars': chars,\n",
        "      'sequence_length': sequence_length\n",
        "  }\n",
        "\n",
        "  return features, labels"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "qsvv-lzbDqXd",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "source": [
        "We now have everything in place to build our custom estimator and use it for training and eval!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 107,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 35
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 5454,
          "status": "ok",
          "timestamp": 1529952160455,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "2pg1AfbxBJQq",
        "outputId": "4aef3052-f7c7-4bb1-a0a2-73fef2e96efb",
        "slideshow": {
          "slide_type": "-"
        }
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Eval loss at step 100: 0.0705221\n"
          ]
        }
      ],
      "source": [
        "params = {\n",
        "    'batch_size': 64,\n",
        "    'learning_rate': 0.01,\n",
        "}\n",
        "\n",
        "train_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/train.csv\"\n",
        "test_url = \"https://raw.githubusercontent.com/random-forests/tensorflow-workshop/master/archive/extras/colorbot/data/test.csv\"\n",
        "data_dir = \"tmp/rnn/data\"\n",
        "\n",
        "regressor = tf.estimator.Estimator(\n",
        "    model_fn=model_fn,\n",
        "    params=params)\n",
        "\n",
        "regressor.train(\n",
        "    input_fn=lambda: input_fn(data_dir, train_url, params),\n",
        "    steps=100)\n",
        "eval_results = regressor.evaluate(\n",
        "    input_fn=lambda: input_fn(data_dir, test_url, params, training=False),\n",
        "    steps=2\n",
        ")\n",
        "\n",
        "print('Eval loss at step %d: %s' % (eval_results['global_step'], eval_results['loss']))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "colab_type": "text",
        "id": "zG1YAjB_cUnQ",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "source": [
        "And here's the same estimator used for inference."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 108,
      "metadata": {
        "colab": {
          "autoexec": {
            "startup": false,
            "wait_interval": 0
          },
          "height": 343
        },
        "colab_type": "code",
        "executionInfo": {
          "elapsed": 3432,
          "status": "ok",
          "timestamp": 1529952163923,
          "user": {
            "displayName": "",
            "photoUrl": "",
            "userId": ""
          },
          "user_tz": 240
        },
        "id": "dxHex2tUN_10",
        "outputId": "1ff438f2-b045-4f4e-86a0-4dae7503f6b2",
        "slideshow": {
          "slide_type": "slide"
        }
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\u003clink rel=stylesheet type=text/css href='/nbextensions/google.colab/tabbar.css'\u003e\u003c/link\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7fcd7222a110\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cscript src='/nbextensions/google.colab/tabbar_main.min.js'\u003e\u003c/script\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7fcd7222a8d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cdiv id=\"id3\"\u003e\u003c/div\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7fcd7222a050\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a03307e-78a7-11e8-99f9-c8d3ffb5fbe0\"] = colab_lib.createTabBar({\"contentBorder\": [\"0px\"], \"elementId\": \"id3\", \"contentHeight\": [\"initial\"], \"tabNames\": [\"RNN Colorbot\"], \"location\": \"top\", \"initialSelection\": 0, \"borderColor\": [\"#a7a7a7\"]});\n",
              "//# sourceURL=js_dc5d7f2784"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222a190\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a03307f-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_be7950150b"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222ac90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_d0c3bd4eaa"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222aad0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
              "//# sourceURL=js_f10f6eba86"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222aed0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a033082-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033081-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_ff29697179"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222abd0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8a033083-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_ff85295dc7"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222ab90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8dc-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8a033080-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_ed7aabfedb"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222a110\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_c86f8feaf4"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222acd0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
              "//# sourceURL=js_4d0fde6662"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222ae50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8df-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8de-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_3f66d52720"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222a210\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e0-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_375f5ae6d7"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd7222a310\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAQwAAAENCAYAAAD60Fs2AAAABHNCSVQICAgIfAhkiAAAAAlwSFlz\nAAALEgAACxIB0t1+/AAABTFJREFUeJzt3C+LV30eh/HP6EZvbP4ZJmkXDA6oQdZRMIhYLIKCMGVA\nyyaLT2ERLMqEDfoUFA2y3WpRrOKoSUSECePcYUEWdsN1OzfOyr5e8ZwT3unie34cfgvb29vbAxDs\n2e0BwK9DMIBMMIBMMIBMMIBMMIBMMPipXrx4MWfOnNntGfwgweCnW1hY2O0J/CDBYEe2trZ2ewI/\nkWDwh509e3bW19fn0qVLc/z48dnY2Jhbt27NyZMn59y5c/Pw4cPvz25ubs7t27dneXl5Ll68OC9f\nvtzF5ezUX3Z7AL+mJ0+ezPr6+uzfv3+uXr0658+fn7t3787GxsbcuHFjjhw5MqdPn5579+7N27dv\n5/nz5/P169dZXV3d7ensgBMGP+T69etz8ODBef369Xz69GnW1tZm7969s7S0NFeuXJnHjx/PzMzT\np09nbW1tfvvttzl48OBcu3Ztl5ezE04Y/JBDhw7NzMy7d+/mw4cPs7y8PDMz29vb8+3btzlx4sTM\nzHz8+PH7szMzi4uLP38sfxrBYEcOHz48S0tL8+zZs/96/8CBA7OxsTFHjx6dmX8Fhl+XVxJ25Nix\nY7Nv375ZX1+fzc3N2dramjdv3nz/cfPChQvz4MGD+fz587x//34ePXq0y4vZCcHgD/v37yj27Nkz\n9+/fn1evXs3KysqcOnVq7ty5M1++fJmZmZs3b87i4uKsrKzM6urqXL58ebdm8ydY8Ac6QOWEAWSC\nAWSCAWSCAWT/s99h/P3GX3d7Avxf+9s//vkf15wwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEww\ngEwwgEwwgEwwgGxhe3t7e7dHAL8GJwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg\nEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwgEwwg+x1QoZHG4XIe4gAAAABJRU5ErkJggg==\n",
            "text/plain": [
              "\u003cmatplotlib.figure.Figure at 0x7fcd0d02dc90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e1-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8dd-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_34b0509660"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e850\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.getActiveOutputArea();\n",
              "//# sourceURL=js_518a0f26fe"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6ec90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"] = document.querySelector(\"#id3_content_0\");\n",
              "//# sourceURL=js_17eb3ff612"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb50\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e4-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e3-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_99da807c8e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6eb90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e5-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"id3\"].setSelectedTabIndex(0);\n",
              "//# sourceURL=js_dee01cb4b6"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e610\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "text/html": [
              "\u003cdiv class=id_853612217 style=\"margin-right:10px; display:flex;align-items:center;\"\u003e\u003cspan style=\"margin-right: 3px;\"\u003e\u003c/span\u003e\u003c/div\u003e"
            ],
            "text/plain": [
              "\u003cIPython.core.display.HTML at 0x7fcd7222aa10\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
              "//# sourceURL=js_8c378be329"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e990\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e7-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e6-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
              "//# sourceURL=js_f0b946600c"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e310\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 input\");\n",
              "//# sourceURL=js_9e21b1373a"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6ea90\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8ea-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8e9-78a7-11e8-99f9-c8d3ffb5fbe0\"].remove();\n",
              "//# sourceURL=js_a7764968c6"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e5d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"] = jQuery(\".id_853612217 span\");\n",
              "//# sourceURL=js_74279d3ff0"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e890\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8ec-78a7-11e8-99f9-c8d3ffb5fbe0\"] = window[\"8b18d8eb-78a7-11e8-99f9-c8d3ffb5fbe0\"].text(\"Give me a color name (or press 'enter' to exit): \");\n",
              "//# sourceURL=js_82b6c34cdb"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3",
              "user_output"
            ]
          },
          "output_type": "display_data"
        },
        {
          "data": {
            "application/javascript": [
              "window[\"8b18d8ed-78a7-11e8-99f9-c8d3ffb5fbe0\"] = google.colab.output.setActiveOutputArea(window[\"8b18d8e2-78a7-11e8-99f9-c8d3ffb5fbe0\"]);\n",
              "//# sourceURL=js_ff6144734a"
            ],
            "text/plain": [
              "\u003cIPython.core.display.Javascript at 0x7fcd08e6e8d0\u003e"
            ]
          },
          "metadata": {
            "tags": [
              "id3_content_0",
              "outputarea_id3"
            ]
          },
          "output_type": "display_data"
        }
      ],
      "source": [
        "def predict_input_fn(color_name):\n",
        "  \"\"\"Builds the Estimator prediction input for one color name.\n",
        "\n",
        "  Args:\n",
        "    color_name: The color name to predict; it is passed to `parse` as is.\n",
        "\n",
        "  Returns:\n",
        "    A `(features, labels)` pair. `features` maps 'chars' and\n",
        "    'sequence_length' to the parsed values, each with a new leading axis of\n",
        "    size one; `labels` is always None.\n",
        "  \"\"\"\n",
        "  # The first value returned by `parse` is not used for prediction.\n",
        "  _, encoded_chars, length = parse(color_name)\n",
        "\n",
        "  # Add a leading axis so the single example forms a batch of one.\n",
        "  batched_chars = tf.expand_dims(encoded_chars, 0)\n",
        "  batched_length = tf.expand_dims(length, 0)\n",
        "  return {'chars': batched_chars, 'sequence_length': batched_length}, None\n",
        "\n",
        "\n",
        "def draw_prediction(color_name, pred):\n",
        "  \"\"\"Shows the predicted color as an image titled with the color name.\n",
        "\n",
        "  Args:\n",
        "    color_name: Text used as the plot title.\n",
        "    pred: NumPy array of color values; it is multiplied by 255 and cast to\n",
        "      uint8 before being displayed (presumably values lie in [0, 1] --\n",
        "      confirm against the caller).\n",
        "  \"\"\"\n",
        "  swatch = (pred * 255).astype(np.uint8)\n",
        "  # Hide the axes so only the color itself is visible.\n",
        "  plt.axis('off')\n",
        "  plt.imshow(swatch)\n",
        "  plt.title(color_name)\n",
        "  plt.show()\n",
        "\n",
        "\n",
        "def predict_with_estimator(color_name, regressor):\n",
        "  \"\"\"Predicts the color for `color_name` with `regressor` and draws it.\n",
        "\n",
        "  Args:\n",
        "    color_name: The color name to predict; also used as the plot title.\n",
        "    regressor: Object whose `predict(input_fn=...)` returns an iterator of\n",
        "      predictions (presumably a trained tf.estimator.Estimator -- confirm\n",
        "      against the cell that builds it).\n",
        "  \"\"\"\n",
        "  predictions = regressor.predict(\n",
        "      input_fn=lambda: predict_input_fn(color_name))\n",
        "  try:\n",
        "    # Only one example is fed in, so only the first prediction is needed.\n",
        "    pred = next(predictions)\n",
        "  finally:\n",
        "    # Always release the prediction iterator, even if `next` raises.\n",
        "    predictions.close()\n",
        "  # Keep every channel in [0, 1]: draw_prediction scales by 255 and casts to\n",
        "  # uint8, where out-of-range values (including negatives) do not map to\n",
        "  # valid intensities.\n",
        "  pred = np.clip(pred, 0.0, 1.0)\n",
        "  # Add two leading axes so the prediction can be shown as an image\n",
        "  # (presumably a single pixel).\n",
        "  pred = np.expand_dims(np.expand_dims(pred, 0), 0)\n",
        "\n",
        "  draw_prediction(color_name, pred)\n",
        "\n",
        "# Interactive demo: ask for color names and draw each prediction in a tab.\n",
        "PROMPT = \"Give me a color name (or press 'enter' to exit): \"\n",
        "tb = widgets.TabBar([\"RNN Colorbot\"])\n",
        "while True:\n",
        "  # Read the next name while output is directed to the first tab.\n",
        "  with tb.output_to(0):\n",
        "    try:\n",
        "      color_name = six.moves.input(PROMPT)\n",
        "    except (EOFError, KeyboardInterrupt):\n",
        "      # End of input or an interrupt ends the session like an empty answer.\n",
        "      color_name = \"\"\n",
        "  if not color_name:\n",
        "    break\n",
        "  # Replace whatever the tab currently shows with the new prediction.\n",
        "  with tb.output_to(0):\n",
        "    tb.clear_tab()\n",
        "    predict_with_estimator(color_name, regressor)"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [],
      "default_view": {},
      "last_runtime": {
        "build_target": "",
        "kind": "local"
      },
      "name": "RNN Colorbot using Keras and Estimators",
      "version": "0.3.2",
      "views": {}
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
